//ArithmeticLogicalUnit.scala
package mycore

import chisel3._
import chisel3.util._

/**
  * Purely combinational ALU.
  *
  * For every supported operation a 64-bit candidate value is computed from
  * `io.src1`, `io.src2` and the immediates in `io.imm`, and is forced to zero
  * unless the matching decode flag in `io.isa` is set. All candidates are then
  * OR-ed together to form `io.result`.
  *
  * NOTE(review): OR-ing the candidates only yields a meaningful value if at
  * most one `io.isa` flag is asserted at a time — presumably guaranteed by the
  * decoder; confirm against the producer of `io.isa`. With no flag asserted
  * the result is zero.
  *
  * `ISA`, `IMM` and `SignExt` are declared elsewhere in the project.
  */
class ArithmeticLogicalUnit extends Module {
	val io = IO(new Bundle {
		val isa = Input(new ISA)			// per-operation decode flags
		val src1 = Input(UInt(64.W))		// first operand
		val src2 = Input(UInt(64.W))		// second operand / register shift amount
		val imm = Input(new IMM)			// immediates (fields I and U are used here)
		val result = Output(UInt(64.W))		// OR of all gated candidates
	})

	/*
		Legacy coding style left over from the days of hand-written multiplexers.

		The code is ugly, but it is fairly convenient to verify...
	*/

	// ANDs `value` with the decode flag `sel` extended to 64 bits via SignExt
	// (presumably an all-ones / all-zeros mask for a one-bit flag — SignExt is
	// defined elsewhere). `value` is by-name so its hardware is elaborated after
	// the mask, exactly as in a plain `mask & value` expression.
	private def maskedBy(sel: Bits)(value: => UInt): UInt =
		SignExt(sel.asUInt, 64) & value

	// Extends `value` to 64 bits with SignExt; used for the 32-bit "word" operations.
	private def signExtTo64(value: UInt): UInt =
		SignExt(value, 64)

	// 64-bit constant 1 when the decode flag and the comparison both hold, else 0.
	private def oneWhen(sel: Bool)(cond: => Bool): UInt =
		Mux(sel && cond, 1.U(64.W), 0.U(64.W))

	// OR-combines candidates left to right.
	private def orAll(candidates: UInt*): UInt =
		candidates.reduceLeft(_ | _)

	//Arithmetic
	val addi	= maskedBy(io.isa.ADDI)		{ io.src1 + io.imm.I }
	val add 	= maskedBy(io.isa.ADD)		{ io.src1 + io.src2 }
	val lui 	= maskedBy(io.isa.LUI)		{ io.imm.U }
	val sub		= maskedBy(io.isa.SUB)		{ io.src1 - io.src2 }
	// Word variants: keep the low 32 bits of the 64-bit sum/difference, then extend.
	val addiw 	= maskedBy(io.isa.ADDIW)	{ signExtTo64((io.src1 + io.imm.I)(31,0)) }
	val addw 	= maskedBy(io.isa.ADDW)		{ signExtTo64((io.src1 + io.src2)(31,0)) }
	val subw 	= maskedBy(io.isa.SUBW)		{ signExtTo64((io.src1 - io.src2)(31,0)) }
	val Arithmetic = orAll(addi, add, lui, sub, addiw, addw, subw)

	//Logical
	val andi	= maskedBy(io.isa.ANDI)		{ io.src1 & io.imm.I }
	val and		= maskedBy(io.isa.AND)		{ io.src1 & io.src2 }
	val ori		= maskedBy(io.isa.ORI)		{ io.src1 | io.imm.I }
	val or		= maskedBy(io.isa.OR)		{ io.src1 | io.src2 }
	val xori	= maskedBy(io.isa.XORI)		{ io.src1 ^ io.imm.I }
	val xor		= maskedBy(io.isa.XOR)		{ io.src1 ^ io.src2 }
	val Logical = orAll(andi, and, ori, or, xori, xor)

	//Compare
	// Signed compares reinterpret the operands as SInt; unsigned ones compare as UInt.
	val slt 	= oneWhen(io.isa.SLT)		{ io.src1.asSInt < io.src2.asSInt }
	val slti 	= oneWhen(io.isa.SLTI)		{ io.src1.asSInt < io.imm.I.asSInt }
	val sltu 	= oneWhen(io.isa.SLTU)		{ io.src1.asUInt < io.src2.asUInt }
	val sltiu 	= oneWhen(io.isa.SLTIU)		{ io.src1.asUInt < io.imm.I.asUInt }
	val Compare = orAll(slt, slti, sltu, sltiu)

	//Shifts
	// 64-bit shifts use a 6-bit shift amount; left shifts are truncated back to 64 bits.
	val sll		= maskedBy(io.isa.SLL)		{ (io.src1 << io.src2(5,0))(63,0) }
	val srl		= maskedBy(io.isa.SRL)		{ io.src1 >> io.src2(5,0) }
	val sra		= maskedBy(io.isa.SRA)		{ (io.src1.asSInt >> io.src2(5,0)).asUInt }
	val slli	= maskedBy(io.isa.SLLI)		{ (io.src1 << io.imm.I(5,0))(63,0) }
	val srli	= maskedBy(io.isa.SRLI)		{ io.src1 >> io.imm.I(5,0) }
	val srai	= maskedBy(io.isa.SRAI)		{ (io.src1.asSInt >> io.imm.I(5,0)).asUInt }
	// Word shifts use a 5-bit shift amount, operate on the low 32 bits of src1,
	// and extend the 32-bit outcome to 64 bits.
	val sllw	= maskedBy(io.isa.SLLW)		{ signExtTo64((io.src1 << io.src2(4,0))(31,0)) }
	val srlw	= maskedBy(io.isa.SRLW)		{ signExtTo64(io.src1(31,0) >> io.src2(4,0)) }
	val sraw	= maskedBy(io.isa.SRAW)		{ signExtTo64((io.src1(31,0).asSInt >> io.src2(4,0)).asUInt) }
	val slliw	= maskedBy(io.isa.SLLIW)	{ signExtTo64((io.src1 << io.imm.I(4,0))(31,0)) }
	val srliw	= maskedBy(io.isa.SRLIW)	{ signExtTo64(io.src1(31,0) >> io.imm.I(4,0)) }
	val sraiw	= maskedBy(io.isa.SRAIW)	{ signExtTo64((io.src1(31,0).asSInt >> io.imm.I(4,0)).asUInt) }
	val Shifts = orAll(sll, srl, sra, slli, srli, srai, sllw, srlw, sraw, slliw, srliw, sraiw)

	// Final result: OR of the four operation groups.
	io.result := orAll(Arithmetic, Logical, Compare, Shifts)

}
